class Dataset:
    """Helpers for eyeballing images in a ROBIN classification dataset.

    Directory layout assumed by the methods below (it is what their path
    joins build)::

        <data_path>/ROBIN-cls-train/<class>/<image files>
        <data_path>/ROBIN-cls-val/<val_type>/<class>/<image files>
    """

    def __init__(self, data_path):
        """Store the dataset root and echo it to stdout.

        Args:
            data_path: Directory that contains the ``ROBIN-cls-train`` and
                ``ROBIN-cls-val`` folders.
        """
        self.data_path = data_path
        print(data_path)

    @staticmethod
    def _visible_entries(path, keep):
        """Return sorted names in *path* that are not hidden and satisfy *keep*.

        Args:
            path: Directory to list.
            keep: Predicate applied to each entry's full path, e.g.
                ``os.path.isdir`` for class folders or ``os.path.isfile``
                for images.

        Returns:
            A sorted list of entry names (not full paths).
        """
        # Dot-entries such as ``.DS_Store`` are neither classes nor images.
        # Sorting makes the order (and any seeded random sampling)
        # independent of the filesystem's listing order.
        return sorted(
            name
            for name in os.listdir(path)
            if not name.startswith('.') and keep(os.path.join(path, name))
        )

    def show_image(self, image_name, title):
        """Read one image file and draw it in a new titled matplotlib figure.

        Args:
            image_name: Path of the image file to read with ``mpimg.imread``.
            title: Title for the new figure.
        """
        image = mpimg.imread(image_name)
        plt.figure()
        plt.title(title)
        plt.imshow(image)

    def show_images(self, image_names, images_path, title):
        """Draw each named image from *images_path*, one figure per image.

        Args:
            image_names: File names relative to *images_path*.
            images_path: Directory containing the images.
            title: Title used for every figure.
        """
        for image_name in image_names:
            self.show_image(os.path.join(images_path, image_name), title)

    def show_random_partition_class(self, partition, data_class, n, val_type='context'):
        """Show up to *n* randomly chosen images of one class.

        Args:
            partition: ``'train'`` or ``'val'``.
            data_class: Class folder name.
            n: Maximum number of images to show. When the class folder has
                at most *n* images, all of them are shown in sorted order.
            val_type: Sub-folder of ``ROBIN-cls-val`` to read from; only used
                when ``partition == 'val'``.

        Raises:
            ValueError: If *partition* is neither ``'train'`` nor ``'val'``,
                or if *n* is negative (raised by ``random.sample``).
        """
        if partition == 'train':
            images_path = os.path.join(self.data_path, 'ROBIN-cls-{}'.format(partition), data_class)
            title = data_class
        elif partition == 'val':
            images_path = os.path.join(self.data_path, 'ROBIN-cls-{}'.format(partition), val_type, data_class)
            title = data_class + '_' + val_type
        else:
            raise ValueError(
                "partition must be 'train' or 'val', got {!r}".format(partition)
            )

        # Only regular, non-hidden files are candidates; stray entries such
        # as .DS_Store would otherwise be handed to imread and fail.
        image_names = self._visible_entries(images_path, os.path.isfile)
        if n < len(image_names):
            image_names = random.sample(image_names, n)
        self.show_images(image_names, images_path, title)

    def visualize_train_classes(self, n):
        """Show up to *n* random images for every training class folder.

        Prints the list of class folder names found before drawing.
        """
        train_path = os.path.join(self.data_path, 'ROBIN-cls-train')
        # Skip hidden entries and plain files: only directories are classes.
        class_names = self._visible_entries(train_path, os.path.isdir)
        print(class_names)
        for class_name in class_names:
            self.show_random_partition_class('train', class_name, n)

    def visualize_val_classes(self, n):
        """Show up to *n* random images for every (val_type, class) folder pair."""
        val_path = os.path.join(self.data_path, 'ROBIN-cls-val')
        for type_name in self._visible_entries(val_path, os.path.isdir):
            type_path = os.path.join(val_path, type_name)
            for class_name in self._visible_entries(type_path, os.path.isdir):
                self.show_random_partition_class('val', class_name, n, type_name)